import os
import cv2
import sys
import time
import glob
import json
import tempfile
import argparse
import threading
from queue import Queue
from decord import VideoReader, cpu

try:
    import grpc
    # NOTE: generate from proto file using:
    # cd jetson/
    # python -m grpc_tools.protoc -I . --python_out=. \
    # --grpc_python_out=. eval_server.proto
    import eval_server_pb2
    import eval_server_pb2_grpc
except Exception:
    print('Cannot find eval_server_pb2.py and eval_server_pb2_grpc.py')

# Registry of evaluation back-ends, keyed by model inputs type.
# Each value is a tuple: (eval binary path, model directory[, trigger
# threshold]). When the optional third element is absent, the command-line
# default `args.th` is used instead (see run_eval_txt / run_eval_server).
Config = {
    'frames': ('eval_servers/eval_r2plus1d', 'baseline_r2plus1d'),
    'visual_token': ('eval_servers/eval_v3', 'xiaodu_hi_v3.3', 0.9),
    'instance': ('eval_servers/eval_v3', 'attn_instance'),
    'without_inst_fm': ('eval_servers/eval_v3', 'attn_without_inst_fm', 0.85),
    'without_inst_cls': ('eval_servers/eval_v3', 'attn_without_inst_cls', 0.65),
    'without_inst_pos': ('eval_servers/eval_v3', 'attn_without_inst_pos', 0.8),
    'inst_crop': ('eval_servers/eval_v3', 'attn_inst_crop')
    # 'visual_token': ('eval_servers/eval_v3', 'xiaodu_hi_v3.3_vis')
}

# Host running the gRPC eval servers ("eval_video" mode connects here).
URL = '10.12.121.6'


def parse_args():
    """Build and parse the command-line options for both eval modes."""
    p = argparse.ArgumentParser(
        description='Parallel evaluation for eval.cpp.')

    # Options shared by both modes.
    p.add_argument(
        '--eval-mode', type=str, default='eval_txt',
        help='Evaluation mode: eval_txt, eval_video.')
    p.add_argument(
        '--logdir', type=str, required=True, help='Log directory.')
    p.add_argument(
        '--model-dir', type=str, default=None,
        help='Directory to saved model.')
    p.add_argument(
        '--gpus', type=str, default='0',
        help='GPUs to use, e.g. "0,1,2,3,4".')

    # Options specific to the "eval_txt" mode.
    txt_group = p.add_argument_group('eval_txt')
    txt_group.add_argument(
        '--eval-txt', '-e', type=str, default='../data/final_eval.txt',
        help='Path to eval data txt which is generated by '
             'train_attention_controller.py.')
    txt_group.add_argument(
        '--inputs-type', '-i', type=str, default='visual_token',
        help='Inputs type of attention controller.')
    txt_group.add_argument(
        '--workers', '-w', type=int, default=4,
        help='Number of parallel workers.')
    txt_group.add_argument(
        '-th', type=float, default=0.8,
        help='Default threshold for trigger, used when no threshold in `Config`.')

    # Options specific to the "eval_video" mode.
    video_group = p.add_argument_group('eval_video')
    video_group.add_argument(
        '--data-dir', type=str, default='../data/eval_videos',
        help='Folder of the evaluation videos and annotations.')
    video_group.add_argument(
        '--num-data-threads', type=int, default=6,
        help='Number of data threads to read videos.')
    video_group.add_argument(
        '--interval', type=int, default=250,
        help='Sampling interval in milliseconds.')
    video_group.add_argument(
        '--occupy', type=int, default=10000,
        help='The milliseconds that robot is occupied after one action.')

    return p.parse_args()

# ==============================
# "eval_txt" mode
# ==============================


def split_eval_txt(txt, workers):
    """Round-robin split the lines of `txt` into `workers` temp files.

    Returns the list of temp-file paths; callers are responsible for
    removing them when done.
    """
    with open(txt, 'r') as f:
        lines = [ln.strip() for ln in f]

    # Line i goes to worker i % workers, i.e. every workers-th line.
    chunks = [lines[k::workers] for k in range(workers)]

    paths = []
    for chunk in chunks:
        fd, path = tempfile.mkstemp(suffix='_eval')
        with os.fdopen(fd, 'w') as tmp:
            tmp.writelines(line + '\n' for line in chunk)
        paths.append(path)

    return paths


def run_eval_txt(idx, path, gpu, args):
    """Run the eval binary for worker `idx` on split file `path` via GPU `gpu`."""
    cfg = Config[args.inputs_type]
    exe_path = cfg[0]
    # --model-dir overrides the Config default when given.
    model_dir = cfg[1] if args.model_dir is None else args.model_dir
    # Per-inputs-type threshold when present, else the CLI default.
    th = cfg[2] if len(cfg) >= 3 else args.th

    logdir = os.path.join(args.logdir, 'worker_{}'.format(idx))
    os.makedirs(logdir, exist_ok=True)

    cmd = ('CUDA_VISIBLE_DEVICES={} {} -gpu '
           '-dirname {} -th {} -logdir {} -dataTxt {} -inputsType {}').format(
        gpu, exe_path, model_dir, th, logdir, path, args.inputs_type)
    os.system(cmd)


def merge_metrics(args):
    """Merge per-worker 'metric.txt' files into a single summary.

    Scans `args.logdir/worker_*/metric.txt` for the counters written by
    the eval binaries, sums them, and writes combined precision/recall
    and the count-weighted average NLL to `args.logdir/metric.txt`.
    """
    counts, nlls = [], []
    trigger_tps, trigger_fps, trigger_fns = [], [], []
    nullAct_tps, nullAct_fps, nullAct_fns = [], [], []

    # Metric-line prefix -> (accumulator list, value parser).
    # Each matching line looks like e.g. 'triggerTP 12'.
    fields = {
        'triggerTP': (trigger_tps, int),
        'triggerFP': (trigger_fps, int),
        'triggerFN': (trigger_fns, int),
        'nullActTP': (nullAct_tps, int),
        'nullActFP': (nullAct_fps, int),
        'nullActFN': (nullAct_fns, int),
        'eeID': (counts, int),
        'actNLL': (nlls, float),
    }
    for txt in glob.glob(os.path.join(args.logdir, 'worker_*', 'metric.txt')):
        with open(txt, 'r') as f:
            for line in f:
                for prefix, (acc, parse) in fields.items():
                    if line.startswith(prefix):
                        acc.append(parse(line.strip().split(' ')[1]))
                        break

    # Smoothed precision/recall so empty counters do not divide by zero.
    eps = 1e-6
    trigger_precision = (sum(trigger_tps) + eps) / \
        (sum(trigger_tps) + sum(trigger_fps) + eps)
    trigger_recall = (sum(trigger_tps) + eps) / \
        (sum(trigger_tps) + sum(trigger_fns) + eps)
    nullAct_precision = (sum(nullAct_tps) + eps) / \
        (sum(nullAct_tps) + sum(nullAct_fps) + eps)
    nullAct_recall = (sum(nullAct_tps) + eps) / \
        (sum(nullAct_tps) + sum(nullAct_fns) + eps)

    # Count-weighted average NLL. BUG FIX: guard against sum(counts) == 0
    # (no worker logs found), which previously raised ZeroDivisionError.
    total = sum(counts)
    avg_nll = sum(c * n for c, n in zip(counts, nlls)) / total if total else 0.0

    with open(os.path.join(args.logdir, 'metric.txt'), 'w') as f:
        f.write('#Trigger\n')
        f.write('TH Precision Recall\n')
        f.write('{} {:.4f} {:.4f}\n'.format(
            args.th, trigger_precision, trigger_recall))

        f.write('#NullAct\n')
        f.write('Precision Recall\n')
        f.write('{:.4f} {:.4f}\n'.format(
            nullAct_precision, nullAct_recall))

        f.write('\n\n')
        f.write('triggerTP: {}\n'.format(sum(trigger_tps)))
        f.write('triggerFP: {}\n'.format(sum(trigger_fps)))
        f.write('triggerFN: {}\n\n'.format(sum(trigger_fns)))
        f.write('nullActTP: {}\n'.format(sum(nullAct_tps)))
        f.write('nullActFP: {}\n'.format(sum(nullAct_fps)))
        f.write('nullActFN: {}\n\n'.format(sum(nullAct_fns)))
        f.write('actNLL: {:.4f}\n'.format(avg_nll))


def parallel_eval_from_txt(args, paths, gpus):
    """Run one eval-binary thread per split file, merge metrics, clean up."""
    threads = []
    for idx in range(args.workers):
        thread = threading.Thread(
            target=run_eval_txt,
            args=(idx, paths[idx], gpus[idx % len(gpus)], args))
        thread.start()
        # Stagger start-up so workers do not all load models at once.
        time.sleep(5)
        threads.append(thread)

    for thread in threads:
        thread.join()

    merge_metrics(args)

    # Remove the temp split files created by split_eval_txt.
    for tmp_path in paths:
        os.remove(tmp_path)

# ==============================
# "eval_video" mode
# ==============================


def timestamp_to_ms(time_str):
    """Convert a 'hour:minute:second.hundredths' timestamp to milliseconds.

    The fractional part is in units of 10 ms (hundredths of a second).
    """
    hour, minute, rest = time_str.split(':')
    sec, hundredths = rest.split('.')
    whole_seconds = (int(hour) * 60 + int(minute)) * 60 + int(sec)
    return whole_seconds * 1000 + int(hundredths) * 10


def ms_to_timestamp(ms):
    """Convert milliseconds to an 'H:M:S.hundredths' string.

    Inverse of `timestamp_to_ms` at 10-ms resolution; fields are not
    zero-padded (e.g. 3661230 -> '1:1:1.23').
    """
    sec, rem_ms = divmod(ms, 1000)
    minute, sec = divmod(sec, 60)
    hour, minute = divmod(minute, 60)
    ten_ms = rem_ms // 10
    return '%s:%s:%s.%s' % (int(hour), int(minute), int(sec), int(ten_ms))


def read_anno_txt(txt):
    """Load JSON-lines annotations, keep true positives, sort by time.

    Each line is a JSON object; its 'Time' field is converted from a
    timestamp string to milliseconds. Entries whose 'PredType' is present
    and not 'TP' are dropped.
    """
    annos = []
    with open(txt, 'r') as f:
        for raw in f:
            anno = json.loads(raw.strip())
            anno['Time'] = timestamp_to_ms(anno['Time'])
            if anno.get('PredType', 'TP') == 'TP':
                annos.append(anno)

    annos.sort(key=lambda a: a['Time'])
    return annos


def calculate_avg_precison(precisions):
    """Interpolated average precision over a threshold sweep.

    `precisions` arrives in increasing-threshold order; it is reversed,
    then each value is replaced by the max of itself and all later values
    (monotone interpolation) before averaging.
    """
    ordered = list(reversed(precisions))
    interpolated = [max(ordered[i:]) for i in range(len(ordered))]
    return sum(interpolated) / len(interpolated)


def calculate_avg_recall(recalls):
    """Interpolated average recall over a threshold sweep.

    Unlike `calculate_avg_precison`, the input (increasing-threshold
    order) is NOT reversed: each value is replaced by the max of itself
    and all later values, then the result is averaged.
    """
    values = list(recalls)
    interpolated = [max(values[i:]) for i in range(len(values))]
    return sum(interpolated) / len(interpolated)


def run_eval_server(idx, gpu, args, first_port=8010):
    """Launch the gRPC eval server for the idx-th `Config` entry.

    The server listens on `first_port + idx` and runs on GPU `gpu`.
    This call blocks until the launched binary exits; it is intended to
    run inside a dedicated thread.
    """
    port = first_port + idx
    inputs_type = list(Config.keys())[idx]
    cfg = Config[inputs_type]
    exe_path = cfg[0]
    # --model-dir overrides the Config default when given.
    model_dir = cfg[1] if args.model_dir is None else args.model_dir

    # BUG FIX: the threshold must come from this server's own `inputs_type`
    # entry, not from `Config[args.inputs_type]` — `--inputs-type` is an
    # eval_txt-mode option that stays at its default ('visual_token') in
    # eval_video mode, so every server previously got the wrong threshold.
    th = cfg[2] if len(cfg) >= 3 else args.th

    extra_args = ''
    if os.path.basename(exe_path) == 'eval_v3':
        if not inputs_type.startswith('inst_crop'):
            extra_args = '-salutation '
        extra_args += '-ensemble -inputsType {} -th {}'.format(inputs_type, th)

    cmd = 'CUDA_VISIBLE_DEVICES={} {} -gpu '.format(gpu, exe_path)
    cmd += '{} -dirname {} -port {}'.format(extra_args, model_dir, port)
    print(cmd)
    os.system(cmd)


def run_data_worker(idx, args, data_queue, safe_gap=10000, nframes=10):
    """Decode evaluation videos and push labeled clips into `data_queue`.

    Worker `idx` processes every `args.num_data_threads`-th video found in
    `args.data_dir`. For each video it emits:

      * negatives — sliding windows of `nframes` frames sampled every
        `args.interval` ms, kept at least `safe_gap` ms away from any
        annotated event;
      * positives — the `nframes` frames ending at each annotation time.

    Each queue item is `(frames, label, video_basename, time_ms)`, where
    `frames` is a list of 640x360 images and label is 1 for positives.
    """
    # A whole clip must fit inside the gap kept clear around annotations,
    # otherwise a "negative" window could overlap a positive event.
    assert args.interval * nframes < safe_gap
    videos = [i for i in os.listdir(args.data_dir) if i.endswith('.mp4')]
    videos = [os.path.join(args.data_dir, i) for i in videos]

    # Round-robin assignment of videos to data workers.
    sub_videos = []
    for i, v in enumerate(videos):
        if i % args.num_data_threads == idx:
            sub_videos.append(v)

    for video in sub_videos:
        # Annotations are expected next to the video as '<name>.mp4_*.txt'
        # (exactly one match assumed — [0] raises IndexError otherwise).
        txt = glob.glob(video + '_*.txt')[0]
        annos = read_anno_txt(txt)
        vr = VideoReader(video, ctx=cpu(idx))
        vid = os.path.basename(video)
        # Per-frame timestamps in ms (decord returns seconds; index [1] is
        # the second element of the (start, end) pair — presumably the
        # frame's end time; confirm against decord docs).
        frame_ts_table = [int(vr.get_frame_timestamp(i)[1] * 1000)
                          for i in range(len(vr))]

        frame_ids = []
        # Three trace pointers: t = time in ms, fid = frame id, aid = anno id.
        t = fid = aid = 0  # three trace pointers

        # O(N+M), N=len(intervals), M=len(vr)
        # NOTE: iteration stops after the last annotation, so the tail
        # segment from the last positive anno to the video end yields no
        # negative examples.
        while aid < len(annos) and \
                t < max(args.interval // 2, annos[aid]['Time'] - safe_gap):
            # Extract negative example: advance fid to the first frame whose
            # timestamp reaches t.
            while fid < len(vr) and frame_ts_table[fid] < t:
                fid += 1

            # Maintain a sliding window of the last `nframes` frame ids.
            if len(frame_ids) < nframes:
                frame_ids.append(fid)
            else:
                frame_ids.pop(0)
                frame_ids.append(fid)

            if len(frame_ids) == nframes:
                # [:, :, ::-1] flips the channel order (RGB<->BGR) —
                # presumably to match the eval server's expected input;
                # confirm against the model side.
                frames = [cv2.resize(img[:, :, ::-1], (640, 360))
                          for img in list(vr.get_batch(frame_ids).asnumpy())]
                data_queue.put((frames, 0, vid, t))

            t += args.interval

            if t >= annos[aid]['Time'] - safe_gap:
                # Extract positive example: the nframes frames ending at the
                # annotation time.
                frame_ids.clear()
                for i in range(nframes - 1, -1, -1):
                    # i = 9, 8, ..., 0 when nframes = 10
                    t = annos[aid]['Time'] - args.interval * i
                    # NOTE(review): unlike the scan above, fid is not bounded
                    # by len(vr) here — an annotation too close to the video
                    # end would raise IndexError; confirm the data guarantees
                    # this cannot happen.
                    while frame_ts_table[fid] < t:
                        fid += 1
                    frame_ids.append(fid)

                frames = [cv2.resize(img[:, :, ::-1], (640, 360))
                          for img in list(vr.get_batch(frame_ids).asnumpy())]
                data_queue.put((frames, 1, vid, t))
                frame_ids.clear()

                # Jump past this event's occupied window and drop any further
                # annotations falling inside it.
                t = annos[aid]['Time'] + safe_gap
                while aid < len(annos) and t >= annos[aid]['Time'] - safe_gap:
                    aid += 1

def run_eval_worker(idx, args, data_queue, first_port=8010):
    """Consume clips from `data_queue`, query the idx-th eval server over
    gRPC, and maintain NullAct / Ensemble / Trigger metrics on disk.

    Runs forever (intended as a daemon thread). Queue items are
    `(frames, label, vid, t)` tuples as produced by `run_data_worker`
    (label 1 = positive clip). Per-prediction records are appended to
    per-video txt files under `args.logdir/<inputs_type>/`, and the
    aggregate metrics file is rewritten after every processed clip so it
    always reflects the current totals.
    """
    tp = fp = fn = 0  # for null act
    tp_ens = fp_ens = fn_ens = 0  # for ensemble
    tp_dict, fp_dict, fn_dict = dict(), dict(), dict()  # for trigger
    # Trigger threshold sweep: 0.05, 0.10, ..., up to (but excluding) 1.0.
    th_lst, th = [], 0.05
    while th < 1.0:
        th_lst.append(th)
        tp_dict[th], fp_dict[th], fn_dict[th] = 0, 0, 0
        th += 0.05

    # Worker idx is paired with the server launched for the idx-th Config
    # key (same ordering as run_eval_server), listening on first_port + idx.
    inputs_type = list(Config.keys())[idx]
    metric_log = os.path.join(args.logdir, '{}.txt'.format(inputs_type))
    pred_dir = os.path.join(args.logdir, inputs_type)
    os.makedirs(pred_dir, exist_ok=True)

    channel = grpc.insecure_channel('{}:{}'.format(URL, first_port + idx))
    stub = eval_server_pb2_grpc.EvalServerStub(channel)
    # Per-video timestamp of the last clip actually sent for evaluation.
    last_t = dict()
    while True:
        frames, label, vid, t = data_queue.get()
        if inputs_type == 'frames':
            # The r2plus1d baseline only takes the last 8 frames —
            # presumably its fixed clip length; confirm against the server.
            frames = frames[-8:]

        # Simulate the robot being busy: skip clips arriving within
        # `args.occupy` ms of the previously evaluated clip of this video.
        if vid in last_t and t - last_t[vid] < args.occupy:
            continue
        elif vid in last_t and t - last_t[vid] >= args.occupy:
            last_t[vid] = t
        elif vid not in last_t:
            last_t[vid] = t

        # Concatenate raw frame bytes into one payload for the request.
        frames_bytes = None
        for frame in frames:
            if frames_bytes is None:
                frames_bytes = bytes(frame)
            else:
                frames_bytes += bytes(frame)

        req = eval_server_pb2.EvalRequest(
            nframe=len(frames), frames=frames_bytes)
        res = stub.infer(req)

        # Decorate the server's JSON response with scores for logging.
        pred = json.loads(res.response)
        pred['Time'] = ms_to_timestamp(t)
        pred['ActScore'] = round(res.response_score, 4)
        pred['Trigger'] = round(res.trigger_pred, 4)
        pred['NullActScore'] = round(res.nullact_score, 4)
        pred['NullActID'] = res.nullact_id

        # NullAct: a nonzero nullact_id counts as a positive prediction.
        # A TN (label 0, nullact_id 0) leaves pred without 'PredType' and
        # is therefore not logged below.
        if label == 1 and res.nullact_id != 0:
            tp += 1
            pred['PredType'] = 'TP'
        elif label == 0 and res.nullact_id != 0:
            fp += 1
            pred['PredType'] = 'FP'
        elif label == 1 and res.nullact_id == 0:
            fn += 1
            pred['PredType'] = 'FN'

        nullact_txt = os.path.join(pred_dir, '{}_nullact.txt'.format(vid))
        with open(nullact_txt, 'a') as f:
            if 'PredType' in pred:
                f.write(json.dumps(pred, ensure_ascii=False) + '\n')
                # Remove the tag so the next metric reuses a clean record.
                del pred['PredType']

        if inputs_type != 'frames':
            # Ensemble: presence of a 'Talk' key in the server response is
            # treated as a positive — presumably the ensemble's chosen
            # action; confirm against the server's response schema.
            if label == 1 and 'Talk' in pred:
                tp_ens += 1
                pred['PredType'] = 'TP'
            elif label == 0 and 'Talk' in pred:
                fp_ens += 1
                pred['PredType'] = 'FP'
            elif label == 1 and 'Talk' not in pred:
                fn_ens += 1
                pred['PredType'] = 'FN'

            ensemble_txt = os.path.join(
                pred_dir, '{}_ensemble.txt'.format(vid))
            with open(ensemble_txt, 'a') as f:
                if 'PredType' in pred:
                    f.write(json.dumps(pred, ensure_ascii=False) + '\n')
                    del pred['PredType']

            # Trigger: evaluate the same prediction at every threshold.
            for th in th_lst:
                if label == 1 and res.trigger_pred > th:
                    tp_dict[th] += 1
                    pred['PredType'] = 'TP'
                elif label == 0 and res.trigger_pred > th:
                    fp_dict[th] += 1
                    pred['PredType'] = 'FP'
                # NOTE(review): a positive with trigger_pred exactly == th
                # falls through uncounted (neither TP nor FN) — confirm
                # this boundary behavior is intended.
                elif label == 1 and res.trigger_pred < th:
                    fn_dict[th] += 1
                    pred['PredType'] = 'FN'

                trigger_txt = os.path.join(
                    pred_dir, '{}_trigger_th{:.2f}.txt'.format(vid, th))
                with open(trigger_txt, 'a') as f:
                    if 'PredType' in pred:
                        f.write(json.dumps(pred, ensure_ascii=False) + '\n')
                        del pred['PredType']

        # Rewrite the whole metric file from the running totals; 1e-8
        # smoothing avoids division by zero before any counts accumulate.
        with open(metric_log, 'w') as f:
            precision = (tp + 1e-8) / (tp + fp + 1e-8)
            recall = (tp + 1e-8) / (tp + fn + 1e-8)
            f1 = 2 * precision * recall / (precision + recall)
            f.write('#NullAct\n')
            f.write('TP FP FN P R F1\n')
            f.write('{} {} {} {:.4f} {:.4f} {:.4f}\n'.format(
                tp, fp, fn, precision, recall, f1))

            if inputs_type != 'frames':
                precision = (tp_ens + 1e-8) / (tp_ens + fp_ens + 1e-8)
                recall = (tp_ens + 1e-8) / (tp_ens + fn_ens + 1e-8)
                f1 = 2 * precision * recall / (precision + recall)
                f.write('#Ensemble\n')
                f.write('TP FP FN P R F1\n')
                f.write('{} {} {} {:.4f} {:.4f} {:.4f}\n'.format(
                    tp_ens, fp_ens, fn_ens, precision, recall, f1))

                f.write('#Trigger\n')
                f.write('TH P R\n')
                precision_dict, recall_dict = dict(), dict()
                for th in th_lst:
                    precision_dict[th] = (tp_dict[th] + 1e-8) / \
                        (tp_dict[th] + fp_dict[th] + 1e-8)
                    recall_dict[th] = (tp_dict[th] + 1e-8) / \
                        (tp_dict[th] + fn_dict[th] + 1e-8)
                    f1 = 2 * precision_dict[th] * recall_dict[th] / \
                        (precision_dict[th] + recall_dict[th])
                    f.write('{:.2f} {:.4f} {:.4f} {:.4f}\n'.format(
                        th, precision_dict[th], recall_dict[th], f1))

                ap = calculate_avg_precison(
                    [precision_dict[th] for th in th_lst])
                ar = calculate_avg_recall(
                    [recall_dict[th] for th in th_lst])
                f.write('AP AR\n')
                f.write('{} {}\n'.format(ap, ar))


def parallel_eval_from_videos(args, gpus):
    """Evaluate all configured models on raw videos simultaneously.

    Starts one gRPC eval server per `Config` entry (spread across `gpus`),
    a pool of data threads that decode videos into a shared queue, and one
    eval worker per server. The main thread then fans every clip out to
    each model's private queue forever (the process runs until killed).
    """
    server_threads = []
    for idx in range(len(Config)):
        server = threading.Thread(
            target=run_eval_server,
            args=(idx, gpus[idx % len(gpus)], args))
        # BUG FIX: `thread.setDaemon = True` merely shadowed the method
        # with the value True and never marked the thread as a daemon;
        # assign the `daemon` attribute instead so these threads do not
        # block interpreter exit.
        server.daemon = True
        server.start()
        time.sleep(5)  # stagger server start-up
        server_threads.append(server)

    time.sleep(30)  # wait for all gRPC servers to finish initializing

    shared_q = Queue(100)
    data_threads = []
    for idx in range(args.num_data_threads):
        reader = threading.Thread(
            target=run_data_worker,
            args=(idx, args, shared_q))
        reader.daemon = True  # BUG FIX: was `setDaemon = True` (no effect)
        reader.start()
        data_threads.append(reader)

    distri_qs = []
    for idx in range(len(Config)):
        distri_qs.append(Queue(100))
        evaluator = threading.Thread(
            target=run_eval_worker,
            args=(idx, args, distri_qs[idx]))
        evaluator.daemon = True  # BUG FIX: was `setDaemon = True` (no effect)
        evaluator.start()

    # Replicate every decoded clip to each model's private queue.
    while True:
        data = shared_q.get()
        for q in distri_qs:
            q.put(data)


if __name__ == '__main__':
    # With no arguments, show the help text instead of running on defaults.
    if len(sys.argv) == 1:
        sys.argv.append('-h')
    args = parse_args()
    gpus = args.gpus.split(',')

    os.makedirs(args.logdir, exist_ok=True)

    if args.eval_mode == 'eval_txt':
        # BUG FIX: split the eval txt only in eval_txt mode. Previously it
        # was always read, which crashed eval_video runs when the default
        # '../data/final_eval.txt' was absent and leaked temp files.
        paths = split_eval_txt(args.eval_txt, args.workers)
        parallel_eval_from_txt(args, paths, gpus)
    elif args.eval_mode == 'eval_video':
        parallel_eval_from_videos(args, gpus)
